3030
3131import java .nio .ByteBuffer ;
3232import java .nio .ByteOrder ;
33- import java .util .ArrayList ;
34- import java .util .List ;
35- import java .util .SortedMap ;
33+ import java .util .*;
3634import java .util .stream .Collectors ;
3735
3836import com .google .protobuf .ByteString ;
@@ -72,7 +70,11 @@ public int getDim() throws IllegalResponseException {
7270 if (!isVectorField ()) {
7371 throw new IllegalResponseException ("Not a vector field" );
7472 }
75- return (int ) fieldData .getVectors ().getDim ();
73+ return getDimInternal (fieldData .getVectors ());
74+ }
75+
76+ private int getDimInternal (VectorField vector ) {
77+ return (int ) vector .getDim ();
7678 }
7779
7880 // this method returns bytes size of each vector according to vector type
@@ -106,16 +108,16 @@ private int checkDim(DataType dt, ByteString data, int dim) {
106108 return 0 ;
107109 }
108110
109- private ByteString getVectorBytes (FieldData fieldData , DataType dt ) {
111+ private ByteString getVectorBytes (VectorField vd , DataType dt ) {
110112 ByteString data ;
111113 if (dt == DataType .BinaryVector ) {
112- data = fieldData . getVectors () .getBinaryVector ();
114+ data = vd .getBinaryVector ();
113115 } else if (dt == DataType .Float16Vector ) {
114- data = fieldData . getVectors () .getFloat16Vector ();
116+ data = vd .getFloat16Vector ();
115117 } else if (dt == DataType .BFloat16Vector ) {
116- data = fieldData . getVectors () .getBfloat16Vector ();
118+ data = vd .getBfloat16Vector ();
117119 } else if (dt == DataType .Int8Vector ) {
118- data = fieldData . getVectors () .getInt8Vector ();
120+ data = vd .getInt8Vector ();
119121 } else {
120122 String msg = String .format ("Unsupported data type %s returned by FieldData" , dt .name ());
121123 throw new IllegalResponseException (msg );
@@ -148,7 +150,7 @@ public long getRowCount() throws IllegalResponseException {
148150 case BFloat16Vector :
149151 case Int8Vector : {
150152 int dim = getDim ();
151- ByteString data = getVectorBytes (fieldData , dt );
153+ ByteString data = getVectorBytes (fieldData . getVectors () , dt );
152154 int bytePerVec = checkDim (dt , data , dim );
153155
154156 return data .size ()/bytePerVec ;
@@ -176,6 +178,20 @@ public long getRowCount() throws IllegalResponseException {
176178 return fieldData .getScalars ().getJsonData ().getDataCount ();
177179 case Array :
178180 return fieldData .getScalars ().getArrayData ().getDataCount ();
181+ case ArrayOfStruct : {
182+ List <FieldData > structData = fieldData .getStructArrays ().getFieldsList ();
183+ for (FieldData fd : structData ) {
184+ if (fd .getType () == DataType .Array ) {
185+ return fd .getScalars ().getArrayData ().getDataCount ();
186+ } else if (fd .getType () == DataType .ArrayOfVector ) {
187+ FieldDataWrapper tempWrapper = new FieldDataWrapper (fd );
188+ return tempWrapper .getRowCount ();
189+ }
190+ }
191+ }
192+ case ArrayOfVector : {
193+ return fieldData .getVectors ().getVectorArray ().getDataCount ();
194+ }
179195 default :
180196 throw new IllegalResponseException ("Unsupported data type returned by FieldData" );
181197 }
@@ -194,6 +210,7 @@ public long getRowCount() throws IllegalResponseException {
194210 * Varchar field returns List of String
195211 * Array field returns List of List
196212 * JSON field returns List of String;
213+ * Struct field returns List of List<Map<String, Object>>
197214 * etc.
198215 *
199216 * Throws {@link IllegalResponseException} if the field type is illegal.
@@ -211,10 +228,51 @@ public List<?> getFieldData() throws IllegalResponseException {
211228
212229 private List <?> getFieldDataInternal () throws IllegalResponseException {
213230 DataType dt = fieldData .getType ();
231+ switch (dt ) {
232+ case FloatVector :
233+ case BinaryVector :
234+ case Float16Vector :
235+ case BFloat16Vector :
236+ case Int8Vector :
237+ case SparseFloatVector :
238+ return getVectorData (dt , fieldData .getVectors ());
239+ case Array :
240+ case Int64 :
241+ case Int32 :
242+ case Int16 :
243+ case Int8 :
244+ case Bool :
245+ case Float :
246+ case Double :
247+ case VarChar :
248+ case String :
249+ case JSON :
250+ return getScalarData (dt , fieldData .getScalars (), fieldData .getValidDataList ());
251+ case ArrayOfStruct :
252+ return getStructData (fieldData .getStructArrays (), fieldData .getFieldName ());
253+ default :
254+ throw new IllegalResponseException ("Unsupported data type returned by FieldData" );
255+ }
256+ }
257+
258+ private List <?> setNoneData (List <?> data , List <Boolean > validData ) {
259+ if (validData != null && validData .size () == data .size ()) {
260+ List <?> newData = new ArrayList <>(data ); // copy the list since the data is come from grpc is not mutable
261+ for (int i = 0 ; i < validData .size (); i ++) {
262+ if (validData .get (i ) == Boolean .FALSE ) {
263+ newData .set (i , null );
264+ }
265+ }
266+ return newData ;
267+ }
268+ return data ;
269+ }
270+
271+ private List <?> getVectorData (DataType dt , VectorField vector ) {
214272 switch (dt ) {
215273 case FloatVector : {
216- int dim = getDim ( );
217- List <Float > data = fieldData . getVectors () .getFloatVector ().getDataList ();
274+ int dim = getDimInternal ( vector );
275+ List <Float > data = vector .getFloatVector ().getDataList ();
218276 if (data .size () % dim != 0 ) {
219277 String msg = String .format ("Returned float vector data array size %d doesn't match dimension %d" ,
220278 data .size (), dim );
@@ -232,10 +290,10 @@ private List<?> getFieldDataInternal() throws IllegalResponseException {
232290 case Float16Vector :
233291 case BFloat16Vector :
234292 case Int8Vector : {
235- int dim = getDim ( );
236- ByteString data = getVectorBytes (fieldData , dt );
293+ int dim = getDimInternal ( vector );
294+ ByteString data = getVectorBytes (vector , dt );
237295 int bytePerVec = checkDim (dt , data , dim );
238- int count = data .size ()/ bytePerVec ;
296+ int count = data .size () / bytePerVec ;
239297 List <ByteBuffer > packData = new ArrayList <>();
240298 for (int i = 0 ; i < count ; ++i ) {
241299 ByteBuffer bf = ByteBuffer .allocate (bytePerVec );
@@ -252,7 +310,7 @@ private List<?> getFieldDataInternal() throws IllegalResponseException {
252310 // in Java sdk, each sparse vector is pairs of long+float
253311 // in server side, each sparse vector is stored as uint+float (8 bytes)
254312 // don't use sparseArray.getDim() because the dim is the max index of each rows
255- SparseFloatArray sparseArray = fieldData . getVectors () .getSparseFloatVector ();
313+ SparseFloatArray sparseArray = vector .getSparseFloatVector ();
256314 List <SortedMap <Long , Float >> packData = new ArrayList <>();
257315 for (int i = 0 ; i < sparseArray .getContentsCount (); ++i ) {
258316 ByteString bs = sparseArray .getContents (i );
@@ -262,34 +320,9 @@ private List<?> getFieldDataInternal() throws IllegalResponseException {
262320 }
263321 return packData ;
264322 }
265- case Array :
266- case Int64 :
267- case Int32 :
268- case Int16 :
269- case Int8 :
270- case Bool :
271- case Float :
272- case Double :
273- case VarChar :
274- case String :
275- case JSON :
276- return getScalarData (dt , fieldData .getScalars (), fieldData .getValidDataList ());
277323 default :
278- throw new IllegalResponseException ("Unsupported data type returned by FieldData" );
279- }
280- }
281-
282- private List <?> setNoneData (List <?> data , List <Boolean > validData ) {
283- if (validData != null && validData .size () == data .size ()) {
284- List <?> newData = new ArrayList <>(data ); // copy the list since the data is come from grpc is not mutable
285- for (int i = 0 ; i < validData .size (); i ++) {
286- if (validData .get (i ) == Boolean .FALSE ) {
287- newData .set (i , null );
288- }
289- }
290- return newData ;
324+ return new ArrayList <>();
291325 }
292- return data ;
293326 }
294327
295328 private List <?> getScalarData (DataType dt , ScalarField scalar , List <Boolean > validData ) {
@@ -315,7 +348,7 @@ private List<?> getScalarData(DataType dt, ScalarField scalar, List<Boolean> val
315348 return dataList .stream ().map (ByteString ::toStringUtf8 ).collect (Collectors .toList ());
316349 case Array :
317350 List <List <?>> array = new ArrayList <>();
318- ArrayArray arrArray = fieldData . getScalars () .getArrayData ();
351+ ArrayArray arrArray = scalar .getArrayData ();
319352 boolean nullable = validData != null && validData .size () == arrArray .getDataCount ();
320353 for (int i = 0 ; i < arrArray .getDataCount (); i ++) {
321354 if (nullable && validData .get (i ) == Boolean .FALSE ) {
@@ -331,6 +364,70 @@ private List<?> getScalarData(DataType dt, ScalarField scalar, List<Boolean> val
331364 }
332365 }
333366
367+ private List <?> getStructData (StructArrayField field , String fieldName ) {
368+ List <List <Map <String , Object >>> packData = new ArrayList <>();
369+ if (field .getFieldsCount () == 0 ) {
370+ return packData ;
371+ }
372+
373+ // read column data from FieldData
374+ // for a struct with two sub-fields "int" and "emb", search with nq=2, topk=3
375+ // the column data is like this:
376+ // {
377+ // "int": [[x1, x2], [x1, x2, x3], [x1], [x1, x2], [x1, x2, x3], [x1]],
378+ // "emb": [[emb1, emb2], [emb1, emb2, emb3], [emb1], [emb1m emb2], [emb1, emb2, emb3], [emb1]],
379+ // }
380+ Map <String , List <List <?>>> columnsData = new HashMap <>();
381+ int rowCount = 0 ;
382+ for (FieldData fd : field .getFieldsList ()) {
383+ List <List <?>> column = new ArrayList <>();
384+ if (fd .getType () == DataType .Array ) {
385+ column = (List <List <?>>) getScalarData (fd .getType (), fd .getScalars (), fd .getValidDataList ());
386+ columnsData .put (fd .getFieldName (), column );
387+ rowCount = column .size ();
388+ } else if (fd .getType () == DataType .ArrayOfVector ) {
389+ VectorArray vecArr = fd .getVectors ().getVectorArray ();
390+ for (VectorField vf : vecArr .getDataList ()) {
391+ List <?> vector = getVectorData (vecArr .getElementType (), vf );
392+ column .add (vector );
393+ }
394+ rowCount = column .size ();
395+ columnsData .put (fd .getFieldName (), column );
396+ } else {
397+ throw new IllegalResponseException ("Unsupported data type returned by StructArrayField" );
398+ }
399+ }
400+
401+ // convert column data into struct list, eventually, the packData is like this:
402+ // [
403+ // [{x1, emb1}, {x2, emb2}],
404+ // [{x1, emb1}, {x2, emb2}, {x3, emb3}],
405+ // [{x1, emb1}],
406+ // [{x1, emb1}, {x2, emb2}],
407+ // [{x1, emb1}, {x2, emb2}, {x3, emb3}],
408+ // [{x1, emb1}]
409+ // ]
410+ for (int i = 0 ; i < rowCount ; i ++) {
411+ int elementCount = 0 ;
412+ Map <String , List <?>> rowColumn = new HashMap <>();
413+ for (String key : columnsData .keySet ()) {
414+ List <?> val = columnsData .get (key ).get (i );
415+ rowColumn .put (key , val );
416+ elementCount = val .size ();
417+ }
418+
419+ List <Map <String , Object >> structs = new ArrayList <>();
420+ for (int k = 0 ; k < elementCount ; k ++) {
421+ Map <String , Object > struct = new HashMap <>();
422+ int finalK = k ;
423+ rowColumn .forEach ((key , val )->struct .put (key , val .get (finalK )));
424+ structs .add (struct );
425+ }
426+ packData .add (structs );
427+ }
428+ return packData ;
429+ }
430+
334431 public Integer getAsInt (int index , String paramName ) throws IllegalResponseException {
335432 if (isJsonField ()) {
336433 String result = getAsString (index , paramName );
0 commit comments