Skip to content

Commit

Permalink
enhance: [GoSDK] Add Slice method for Vector Columns (milvus-io#37951)
Browse files Browse the repository at this point in the history
Related to milvus-io#37768

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
  • Loading branch information
congqixia authored and JsDove committed Nov 26, 2024
1 parent 9cb6052 commit 6124d23
Show file tree
Hide file tree
Showing 7 changed files with 198 additions and 3 deletions.
6 changes: 6 additions & 0 deletions client/column/sparse.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,9 @@ func (c *ColumnSparseFloatVector) FieldData() *schemapb.FieldData {
vectors.Dim = int64(max.Dim())
return fd
}

func (c *ColumnSparseFloatVector) Slice(start, end int) Column {
return &ColumnSparseFloatVector{
vectorBase: c.vectorBase.slice(start, end),
}
}
9 changes: 9 additions & 0 deletions client/column/sparse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,4 +91,13 @@ func TestColumnSparseEmbedding(t *testing.T) {
assert.Equal(t, v, getV)
}
})

t.Run("test_column_slice", func(t *testing.T) {
l := rand.Intn(columnLen)
sliced := column.Slice(0, l)
slicedColumn, ok := sliced.(*ColumnSparseFloatVector)
if assert.True(t, ok) {
assert.Equal(t, column.Data()[:l], slicedColumn.Data())
}
})
}
31 changes: 31 additions & 0 deletions client/column/vector.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@ func (b *vectorBase[T]) FieldData() *schemapb.FieldData {
return fd
}

func (b *vectorBase[T]) slice(start, end int) *vectorBase[T] {
return &vectorBase[T]{
genericColumnBase: b.genericColumnBase.slice(start, end),
dim: b.dim,
}
}

func newVectorBase[T entity.Vector](fieldName string, dim int, vectors []T, fieldType entity.FieldType) *vectorBase[T] {
return &vectorBase[T]{
genericColumnBase: &genericColumnBase[T]{
Expand Down Expand Up @@ -78,6 +85,12 @@ func (c *ColumnFloatVector) AppendValue(i interface{}) error {
return nil
}

func (c *ColumnFloatVector) Slice(start, end int) Column {
return &ColumnFloatVector{
vectorBase: c.vectorBase.slice(start, end),
}
}

/* binary vector */

type ColumnBinaryVector struct {
Expand Down Expand Up @@ -105,6 +118,12 @@ func (c *ColumnBinaryVector) AppendValue(i interface{}) error {
return nil
}

func (c *ColumnBinaryVector) Slice(start, end int) Column {
return &ColumnBinaryVector{
vectorBase: c.vectorBase.slice(start, end),
}
}

/* fp16 vector */

type ColumnFloat16Vector struct {
Expand Down Expand Up @@ -132,6 +151,12 @@ func (c *ColumnFloat16Vector) AppendValue(i interface{}) error {
return nil
}

func (c *ColumnFloat16Vector) Slice(start, end int) Column {
return &ColumnFloat16Vector{
vectorBase: c.vectorBase.slice(start, end),
}
}

type ColumnBFloat16Vector struct {
*vectorBase[entity.BFloat16Vector]
}
Expand All @@ -156,3 +181,9 @@ func (c *ColumnBFloat16Vector) AppendValue(i interface{}) error {
}
return nil
}

func (c *ColumnBFloat16Vector) Slice(start, end int) Column {
return &ColumnBFloat16Vector{
vectorBase: c.vectorBase.slice(start, end),
}
}
90 changes: 90 additions & 0 deletions client/column/vector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,96 @@ func (s *VectorSuite) TestBasic() {
})
}

func (s *VectorSuite) TestSlice() {
s.Run("float_vector", func() {
name := fmt.Sprintf("field_%d", rand.Intn(1000))
n := 100
dim := rand.Intn(10) + 2
data := make([][]float32, 0, n)
for i := 0; i < n; i++ {
row := lo.RepeatBy(dim, func(i int) float32 {
return rand.Float32()
})
data = append(data, row)
}
column := NewColumnFloatVector(name, dim, data)

l := rand.Intn(n)
sliced := column.Slice(0, l)
slicedColumn, ok := sliced.(*ColumnFloatVector)
if s.True(ok) {
s.Equal(dim, slicedColumn.Dim())
s.Equal(lo.Map(data[:l], func(row []float32, _ int) entity.FloatVector { return entity.FloatVector(row) }), slicedColumn.Data())
}
})

s.Run("binary_vector", func() {
name := fmt.Sprintf("field_%d", rand.Intn(1000))
n := 100
dim := (rand.Intn(10) + 1) * 8
data := make([][]byte, 0, n)
for i := 0; i < n; i++ {
row := lo.RepeatBy(dim/8, func(i int) byte {
return byte(rand.Intn(math.MaxUint8))
})
data = append(data, row)
}
column := NewColumnBinaryVector(name, dim, data)

l := rand.Intn(n)
sliced := column.Slice(0, l)
slicedColumn, ok := sliced.(*ColumnBinaryVector)
if s.True(ok) {
s.Equal(dim, slicedColumn.Dim())
s.Equal(lo.Map(data[:l], func(row []byte, _ int) entity.BinaryVector { return entity.BinaryVector(row) }), slicedColumn.Data())
}
})

s.Run("fp16_vector", func() {
name := fmt.Sprintf("field_%d", rand.Intn(1000))
n := 3
dim := rand.Intn(10) + 1
data := make([][]byte, 0, n)
for i := 0; i < n; i++ {
row := lo.RepeatBy(dim*2, func(i int) byte {
return byte(rand.Intn(math.MaxUint8))
})
data = append(data, row)
}
column := NewColumnFloat16Vector(name, dim, data)

l := rand.Intn(n)
sliced := column.Slice(0, l)
slicedColumn, ok := sliced.(*ColumnFloat16Vector)
if s.True(ok) {
s.Equal(dim, slicedColumn.Dim())
s.Equal(lo.Map(data[:l], func(row []byte, _ int) entity.Float16Vector { return entity.Float16Vector(row) }), slicedColumn.Data())
}
})

s.Run("bf16_vector", func() {
name := fmt.Sprintf("field_%d", rand.Intn(1000))
n := 3
dim := rand.Intn(10) + 1
data := make([][]byte, 0, n)
for i := 0; i < n; i++ {
row := lo.RepeatBy(dim*2, func(i int) byte {
return byte(rand.Intn(math.MaxUint8))
})
data = append(data, row)
}
column := NewColumnBFloat16Vector(name, dim, data)

l := rand.Intn(n)
sliced := column.Slice(0, l)
slicedColumn, ok := sliced.(*ColumnBFloat16Vector)
if s.True(ok) {
s.Equal(dim, slicedColumn.Dim())
s.Equal(lo.Map(data[:l], func(row []byte, _ int) entity.BFloat16Vector { return entity.BFloat16Vector(row) }), slicedColumn.Data())
}
})
}

func TestVectors(t *testing.T) {
suite.Run(t, new(VectorSuite))
}
2 changes: 1 addition & 1 deletion client/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ require (
github.com/blang/semver/v4 v4.0.0
github.com/cockroachdb/errors v1.9.1
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20241108105827-266fb751b620
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20241120015424-93892e628c69
github.com/milvus-io/milvus/pkg v0.0.2-0.20241111021426-5e90f348fcbb
github.com/quasilyte/go-ruleguard/dsl v0.3.22
github.com/samber/lo v1.27.0
Expand Down
4 changes: 2 additions & 2 deletions client/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -318,8 +318,8 @@ github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfr
github.com/mediocregopher/radix/v3 v3.4.2/go.mod h1:8FL3F6UQRXHXIBSPUs5h0RybMF8i4n7wVopoX3x7Bv8=
github.com/microcosm-cc/bluemonday v1.0.2/go.mod h1:iVP4YcDBq+n/5fb23BhYFvIMq/leAFZyRl6bYmGDlGc=
github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20241108105827-266fb751b620 h1:0IWUDtDloift7cQHalhdjuVkL/3qSeiXFqR7MofZBkg=
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20241108105827-266fb751b620/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs=
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20241120015424-93892e628c69 h1:Qt0Bv2Fum3EX3OlkuQYHJINBzeU4oEuHy2lXSfB/gZw=
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20241120015424-93892e628c69/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs=
github.com/milvus-io/milvus/pkg v0.0.2-0.20241111021426-5e90f348fcbb h1:lMyIrG03agASB88AAwnk+NOU9V33lcBdtub/ZEv6IQU=
github.com/milvus-io/milvus/pkg v0.0.2-0.20241111021426-5e90f348fcbb/go.mod h1:w5nu1Z318AvgWQrGUYXaqLeVLu4JvCS/oYhxqctOZvU=
github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc=
Expand Down
59 changes: 59 additions & 0 deletions client/milvusclient/mock_milvus_server_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 6124d23

Please sign in to comment.