Skip to content

Commit d194d7c

Browse files
authored
Support Struct field (#1619)
Signed-off-by: yhmo <yihua.mo@zilliz.com>
1 parent a1f5a32 commit d194d7c

File tree

24 files changed

+1024
-209
lines changed

24 files changed

+1024
-209
lines changed

docker-compose.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ version: '3.5'
33
services:
44
standalone:
55
container_name: milvus-javasdk-standalone-1
6-
image: milvusdb/milvus:v2.6.1
6+
image: milvusdb/milvus:master-20250922-200ee4cb-amd64
77
command: [ "milvus", "run", "standalone" ]
88
environment:
99
- COMMON_STORAGETYPE=local
@@ -24,7 +24,7 @@ services:
2424

2525
standaloneslave:
2626
container_name: milvus-javasdk-standalone-2
27-
image: milvusdb/milvus:v2.6.1
27+
image: milvusdb/milvus:master-20250922-200ee4cb-amd64
2828
command: [ "milvus", "run", "standalone" ]
2929
environment:
3030
- COMMON_STORAGETYPE=local

examples/src/main/java/io/milvus/v1/CommonUtils.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ public static SortedMap<Long, Float> generateSparseVector() {
300300
Random ran = new Random();
301301
SortedMap<Long, Float> sparse = new TreeMap<>();
302302
int dim = ran.nextInt(10) + 10;
303-
for (int i = 0; i < dim; ++i) {
303+
while (sparse.size() < dim) {
304304
sparse.put((long)ran.nextInt(1000000), ran.nextFloat());
305305
}
306306
return sparse;

sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/BulkWriter.java

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,9 @@
3333
import io.milvus.bulkwriter.writer.ParquetFileWriter;
3434
import io.milvus.common.utils.ExceptionUtils;
3535
import io.milvus.common.utils.Float16Utils;
36-
import io.milvus.grpc.FieldSchema;
37-
import io.milvus.param.ParamUtils;
3836
import io.milvus.v2.common.DataType;
3937
import io.milvus.v2.service.collection.request.CreateCollectionReq;
40-
import io.milvus.v2.utils.SchemaUtils;
38+
import io.milvus.v2.utils.DataUtils;
4139
import org.apache.commons.collections4.CollectionUtils;
4240
import org.apache.commons.lang3.tuple.Pair;
4341
import org.slf4j.Logger;
@@ -362,8 +360,7 @@ protected Map<String, Object> verifyRow(JsonObject row) {
362360
}
363361

364362
private Pair<Object, Integer> verifyVector(JsonElement object, CreateCollectionReq.FieldSchema field) {
365-
FieldSchema grpcField = SchemaUtils.convertToGrpcFieldSchema(field);
366-
Object vector = ParamUtils.checkFieldValue(ParamUtils.ConvertField(grpcField), object);
363+
Object vector = DataUtils.checkFieldValue(field, object);
367364
io.milvus.v2.common.DataType dataType = field.getDataType();
368365
switch (dataType) {
369366
case FloatVector:
@@ -396,8 +393,7 @@ private Pair<Object, Integer> verifyVarchar(JsonElement object, CreateCollection
396393
return Pair.of(null, 0);
397394
}
398395

399-
FieldSchema grpcField = SchemaUtils.convertToGrpcFieldSchema(field);
400-
Object varchar = ParamUtils.checkFieldValue(ParamUtils.ConvertField(grpcField), object);
396+
Object varchar = DataUtils.checkFieldValue(field, object);
401397
return Pair.of(varchar, String.valueOf(varchar).length());
402398
}
403399

@@ -411,8 +407,7 @@ private Pair<Object, Integer> verifyJSON(JsonElement object, CreateCollectionReq
411407
}
412408

413409
private Pair<Object, Integer> verifyArray(JsonElement object, CreateCollectionReq.FieldSchema field) {
414-
FieldSchema grpcField = SchemaUtils.convertToGrpcFieldSchema(field);
415-
Object array = ParamUtils.checkFieldValue(ParamUtils.ConvertField(grpcField), object);
410+
Object array = DataUtils.checkFieldValue(field, object);
416411
if (array == null) {
417412
return Pair.of(null, 0);
418413
}

sdk-bulkwriter/src/test/java/io/milvus/bulkwriter/TestUtils.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ public List<ByteBuffer> generateBFloat16Vectors(int count) {
9191
public SortedMap<Long, Float> generateSparseVector() {
9292
SortedMap<Long, Float> sparse = new TreeMap<>();
9393
int dim = RANDOM.nextInt(10) + 10;
94-
for (int i = 0; i < dim; ++i) {
94+
while (sparse.size() < dim) {
9595
sparse.put((long) RANDOM.nextInt(1000000), RANDOM.nextFloat());
9696
}
9797
return sparse;

sdk-core/src/main/java/io/milvus/param/ParamUtils.java

Lines changed: 48 additions & 52 deletions
Large diffs are not rendered by default.

sdk-core/src/main/java/io/milvus/response/FieldDataWrapper.java

Lines changed: 140 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,7 @@
3030

3131
import java.nio.ByteBuffer;
3232
import java.nio.ByteOrder;
33-
import java.util.ArrayList;
34-
import java.util.List;
35-
import java.util.SortedMap;
33+
import java.util.*;
3634
import java.util.stream.Collectors;
3735

3836
import com.google.protobuf.ByteString;
@@ -72,7 +70,11 @@ public int getDim() throws IllegalResponseException {
7270
if (!isVectorField()) {
7371
throw new IllegalResponseException("Not a vector field");
7472
}
75-
return (int) fieldData.getVectors().getDim();
73+
return getDimInternal(fieldData.getVectors());
74+
}
75+
76+
private int getDimInternal(VectorField vector) {
77+
return (int) vector.getDim();
7678
}
7779

7880
// this method returns bytes size of each vector according to vector type
@@ -106,16 +108,16 @@ private int checkDim(DataType dt, ByteString data, int dim) {
106108
return 0;
107109
}
108110

109-
private ByteString getVectorBytes(FieldData fieldData, DataType dt) {
111+
private ByteString getVectorBytes(VectorField vd, DataType dt) {
110112
ByteString data;
111113
if (dt == DataType.BinaryVector) {
112-
data = fieldData.getVectors().getBinaryVector();
114+
data = vd.getBinaryVector();
113115
} else if (dt == DataType.Float16Vector) {
114-
data = fieldData.getVectors().getFloat16Vector();
116+
data = vd.getFloat16Vector();
115117
} else if (dt == DataType.BFloat16Vector) {
116-
data = fieldData.getVectors().getBfloat16Vector();
118+
data = vd.getBfloat16Vector();
117119
} else if (dt == DataType.Int8Vector) {
118-
data = fieldData.getVectors().getInt8Vector();
120+
data = vd.getInt8Vector();
119121
} else {
120122
String msg = String.format("Unsupported data type %s returned by FieldData", dt.name());
121123
throw new IllegalResponseException(msg);
@@ -148,7 +150,7 @@ public long getRowCount() throws IllegalResponseException {
148150
case BFloat16Vector:
149151
case Int8Vector: {
150152
int dim = getDim();
151-
ByteString data = getVectorBytes(fieldData, dt);
153+
ByteString data = getVectorBytes(fieldData.getVectors(), dt);
152154
int bytePerVec = checkDim(dt, data, dim);
153155

154156
return data.size()/bytePerVec;
@@ -176,6 +178,20 @@ public long getRowCount() throws IllegalResponseException {
176178
return fieldData.getScalars().getJsonData().getDataCount();
177179
case Array:
178180
return fieldData.getScalars().getArrayData().getDataCount();
181+
case ArrayOfStruct: {
182+
List<FieldData> structData = fieldData.getStructArrays().getFieldsList();
183+
for (FieldData fd : structData) {
184+
if (fd.getType() == DataType.Array) {
185+
return fd.getScalars().getArrayData().getDataCount();
186+
} else if (fd.getType() == DataType.ArrayOfVector) {
187+
FieldDataWrapper tempWrapper = new FieldDataWrapper(fd);
188+
return tempWrapper.getRowCount();
189+
}
190+
}
191+
}
192+
case ArrayOfVector: {
193+
return fieldData.getVectors().getVectorArray().getDataCount();
194+
}
179195
default:
180196
throw new IllegalResponseException("Unsupported data type returned by FieldData");
181197
}
@@ -194,6 +210,7 @@ public long getRowCount() throws IllegalResponseException {
194210
* Varchar field returns List of String
195211
* Array field returns List of List
196212
* JSON field returns List of String;
213+
* Struct field returns List of List<Map<String, Object>>
197214
* etc.
198215
*
199216
* Throws {@link IllegalResponseException} if the field type is illegal.
@@ -211,10 +228,51 @@ public List<?> getFieldData() throws IllegalResponseException {
211228

212229
private List<?> getFieldDataInternal() throws IllegalResponseException {
213230
DataType dt = fieldData.getType();
231+
switch (dt) {
232+
case FloatVector:
233+
case BinaryVector:
234+
case Float16Vector:
235+
case BFloat16Vector:
236+
case Int8Vector:
237+
case SparseFloatVector:
238+
return getVectorData(dt, fieldData.getVectors());
239+
case Array:
240+
case Int64:
241+
case Int32:
242+
case Int16:
243+
case Int8:
244+
case Bool:
245+
case Float:
246+
case Double:
247+
case VarChar:
248+
case String:
249+
case JSON:
250+
return getScalarData(dt, fieldData.getScalars(), fieldData.getValidDataList());
251+
case ArrayOfStruct:
252+
return getStructData(fieldData.getStructArrays(), fieldData.getFieldName());
253+
default:
254+
throw new IllegalResponseException("Unsupported data type returned by FieldData");
255+
}
256+
}
257+
258+
private List<?> setNoneData(List<?> data, List<Boolean> validData) {
259+
if (validData != null && validData.size() == data.size()) {
260+
List<?> newData = new ArrayList<>(data); // copy the list since the data is come from grpc is not mutable
261+
for (int i = 0; i < validData.size(); i++) {
262+
if (validData.get(i) == Boolean.FALSE) {
263+
newData.set(i, null);
264+
}
265+
}
266+
return newData;
267+
}
268+
return data;
269+
}
270+
271+
private List<?> getVectorData(DataType dt, VectorField vector) {
214272
switch (dt) {
215273
case FloatVector: {
216-
int dim = getDim();
217-
List<Float> data = fieldData.getVectors().getFloatVector().getDataList();
274+
int dim = getDimInternal(vector);
275+
List<Float> data = vector.getFloatVector().getDataList();
218276
if (data.size() % dim != 0) {
219277
String msg = String.format("Returned float vector data array size %d doesn't match dimension %d",
220278
data.size(), dim);
@@ -232,10 +290,10 @@ private List<?> getFieldDataInternal() throws IllegalResponseException {
232290
case Float16Vector:
233291
case BFloat16Vector:
234292
case Int8Vector: {
235-
int dim = getDim();
236-
ByteString data = getVectorBytes(fieldData, dt);
293+
int dim = getDimInternal(vector);
294+
ByteString data = getVectorBytes(vector, dt);
237295
int bytePerVec = checkDim(dt, data, dim);
238-
int count = data.size()/bytePerVec;
296+
int count = data.size() / bytePerVec;
239297
List<ByteBuffer> packData = new ArrayList<>();
240298
for (int i = 0; i < count; ++i) {
241299
ByteBuffer bf = ByteBuffer.allocate(bytePerVec);
@@ -252,7 +310,7 @@ private List<?> getFieldDataInternal() throws IllegalResponseException {
252310
// in Java sdk, each sparse vector is pairs of long+float
253311
// in server side, each sparse vector is stored as uint+float (8 bytes)
254312
// don't use sparseArray.getDim() because the dim is the max index of each rows
255-
SparseFloatArray sparseArray = fieldData.getVectors().getSparseFloatVector();
313+
SparseFloatArray sparseArray = vector.getSparseFloatVector();
256314
List<SortedMap<Long, Float>> packData = new ArrayList<>();
257315
for (int i = 0; i < sparseArray.getContentsCount(); ++i) {
258316
ByteString bs = sparseArray.getContents(i);
@@ -262,34 +320,9 @@ private List<?> getFieldDataInternal() throws IllegalResponseException {
262320
}
263321
return packData;
264322
}
265-
case Array:
266-
case Int64:
267-
case Int32:
268-
case Int16:
269-
case Int8:
270-
case Bool:
271-
case Float:
272-
case Double:
273-
case VarChar:
274-
case String:
275-
case JSON:
276-
return getScalarData(dt, fieldData.getScalars(), fieldData.getValidDataList());
277323
default:
278-
throw new IllegalResponseException("Unsupported data type returned by FieldData");
279-
}
280-
}
281-
282-
private List<?> setNoneData(List<?> data, List<Boolean> validData) {
283-
if (validData != null && validData.size() == data.size()) {
284-
List<?> newData = new ArrayList<>(data); // copy the list since the data is come from grpc is not mutable
285-
for (int i = 0; i < validData.size(); i++) {
286-
if (validData.get(i) == Boolean.FALSE) {
287-
newData.set(i, null);
288-
}
289-
}
290-
return newData;
324+
return new ArrayList<>();
291325
}
292-
return data;
293326
}
294327

295328
private List<?> getScalarData(DataType dt, ScalarField scalar, List<Boolean> validData) {
@@ -315,7 +348,7 @@ private List<?> getScalarData(DataType dt, ScalarField scalar, List<Boolean> val
315348
return dataList.stream().map(ByteString::toStringUtf8).collect(Collectors.toList());
316349
case Array:
317350
List<List<?>> array = new ArrayList<>();
318-
ArrayArray arrArray = fieldData.getScalars().getArrayData();
351+
ArrayArray arrArray = scalar.getArrayData();
319352
boolean nullable = validData != null && validData.size() == arrArray.getDataCount();
320353
for (int i = 0; i < arrArray.getDataCount(); i++) {
321354
if (nullable && validData.get(i) == Boolean.FALSE) {
@@ -331,6 +364,70 @@ private List<?> getScalarData(DataType dt, ScalarField scalar, List<Boolean> val
331364
}
332365
}
333366

367+
private List<?> getStructData(StructArrayField field, String fieldName) {
368+
List<List<Map<String, Object>>> packData = new ArrayList<>();
369+
if (field.getFieldsCount() == 0) {
370+
return packData;
371+
}
372+
373+
// read column data from FieldData
374+
// for a struct with two sub-fields "int" and "emb", search with nq=2, topk=3
375+
// the column data is like this:
376+
// {
377+
// "int": [[x1, x2], [x1, x2, x3], [x1], [x1, x2], [x1, x2, x3], [x1]],
378+
// "emb": [[emb1, emb2], [emb1, emb2, emb3], [emb1], [emb1m emb2], [emb1, emb2, emb3], [emb1]],
379+
// }
380+
Map<String, List<List<?>>> columnsData = new HashMap<>();
381+
int rowCount = 0;
382+
for (FieldData fd : field.getFieldsList()) {
383+
List<List<?>> column = new ArrayList<>();
384+
if (fd.getType() == DataType.Array) {
385+
column = (List<List<?>>) getScalarData(fd.getType(), fd.getScalars(), fd.getValidDataList());
386+
columnsData.put(fd.getFieldName(), column);
387+
rowCount = column.size();
388+
} else if (fd.getType() == DataType.ArrayOfVector) {
389+
VectorArray vecArr = fd.getVectors().getVectorArray();
390+
for (VectorField vf : vecArr.getDataList()) {
391+
List<?> vector = getVectorData(vecArr.getElementType(), vf);
392+
column.add(vector);
393+
}
394+
rowCount = column.size();
395+
columnsData.put(fd.getFieldName(), column);
396+
} else {
397+
throw new IllegalResponseException("Unsupported data type returned by StructArrayField");
398+
}
399+
}
400+
401+
// convert column data into struct list, eventually, the packData is like this:
402+
// [
403+
// [{x1, emb1}, {x2, emb2}],
404+
// [{x1, emb1}, {x2, emb2}, {x3, emb3}],
405+
// [{x1, emb1}],
406+
// [{x1, emb1}, {x2, emb2}],
407+
// [{x1, emb1}, {x2, emb2}, {x3, emb3}],
408+
// [{x1, emb1}]
409+
// ]
410+
for (int i = 0; i < rowCount; i++) {
411+
int elementCount = 0;
412+
Map<String, List<?>> rowColumn = new HashMap<>();
413+
for (String key : columnsData.keySet()) {
414+
List<?> val = columnsData.get(key).get(i);
415+
rowColumn.put(key, val);
416+
elementCount = val.size();
417+
}
418+
419+
List<Map<String, Object>> structs = new ArrayList<>();
420+
for (int k = 0; k < elementCount; k++) {
421+
Map<String, Object> struct = new HashMap<>();
422+
int finalK = k;
423+
rowColumn.forEach((key, val)->struct.put(key, val.get(finalK)));
424+
structs.add(struct);
425+
}
426+
packData.add(structs);
427+
}
428+
return packData;
429+
}
430+
334431
public Integer getAsInt(int index, String paramName) throws IllegalResponseException {
335432
if (isJsonField()) {
336433
String result = getAsString(index, paramName);

sdk-core/src/main/java/io/milvus/v2/common/DataType.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,9 @@ public enum DataType {
4545
Float16Vector(102),
4646
BFloat16Vector(103),
4747
SparseFloatVector(104),
48-
Int8Vector(105);
48+
Int8Vector(105),
49+
50+
Struct(201);
4951

5052
private final int code;
5153
DataType(int code) {

sdk-core/src/main/java/io/milvus/v2/common/IndexParam.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ public enum MetricType {
5252

5353
// Only for sparse vector with BM25
5454
BM25,
55+
56+
// Only for struct vector
57+
MAX_SIM,
5558
;
5659
}
5760

@@ -94,7 +97,10 @@ public enum IndexType {
9497
SPARSE_INVERTED_INDEX(300),
9598
// From Milvus 2.5.4 onward, SPARSE_WAND is being deprecated. Instead, it is recommended to
9699
// use "inverted_index_algo": "DAAT_WAND" for equivalency while maintaining compatibility.
97-
SPARSE_WAND(301)
100+
SPARSE_WAND(301),
101+
102+
// Only for struct vector
103+
EMB_LIST_HNSW(401),
98104
;
99105

100106
private final String name;

0 commit comments

Comments
 (0)