Skip to content

Commit 585c74e

Browse files
julienledem authored and kou committed
ARROW-275: Add tests for UnionVector in Arrow File
Author: Julien Le Dem <julien@dremio.com>

Closes apache#169 from julienledem/union_test and squashes the following commits:

120f504 [Julien Le Dem] ARROW-275: Add tests for UnionVector in Arrow File
1 parent c9116bb commit 585c74e

File tree

5 files changed

+127
-22
lines changed

5 files changed

+127
-22
lines changed

vector/src/main/codegen/templates/UnionReader.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,10 @@ public void copyAsValue(UnionWriter writer) {
134134

135135
</#list>
136136

137+
/**
 * Returns the size reported by the reader that {@code getReaderForIndex}
 * selects for the current index ({@code idx()}).
 *
 * @return the result of calling {@code size()} on that reader
 */
public int size() {
138+
return getReaderForIndex(idx()).size();
139+
}
140+
137141
<#list vv.types as type><#list type.minor as minor><#assign name = minor.class?cap_first />
138142
<#assign uncappedName = name?uncap_first/>
139143
<#assign boxedType = (minor.boxedType!type.boxedType) />

vector/src/main/codegen/templates/UnionVector.java

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,6 @@
1515
* See the License for the specific language governing permissions and
1616
* limitations under the License.
1717
*/
18-
19-
import com.google.common.collect.ImmutableList;
20-
import com.google.flatbuffers.FlatBufferBuilder;
21-
import io.netty.buffer.ArrowBuf;
22-
import org.apache.arrow.flatbuf.Field;
23-
import org.apache.arrow.flatbuf.Type;
24-
import org.apache.arrow.flatbuf.Union;
25-
import org.apache.arrow.vector.ValueVector;
26-
import org.apache.arrow.vector.types.pojo.ArrowType;
27-
28-
import java.util.ArrayList;
2918
import java.util.List;
3019

3120
<@pp.dropOutputFile />
@@ -39,14 +28,17 @@
3928
<#include "/@includes/vv_imports.ftl" />
4029
import com.google.common.collect.ImmutableList;
4130
import java.util.ArrayList;
31+
import java.util.Collections;
4232
import java.util.Iterator;
33+
import org.apache.arrow.vector.BaseDataValueVector;
4334
import org.apache.arrow.vector.complex.impl.ComplexCopier;
4435
import org.apache.arrow.vector.util.CallBack;
4536
import org.apache.arrow.vector.schema.ArrowFieldNode;
4637

4738
import static org.apache.arrow.flatbuf.UnionMode.Sparse;
4839

4940

41+
5042
/*
5143
* This class is generated using freemarker and the ${.template_name} template.
5244
*/
@@ -81,13 +73,15 @@ public class UnionVector implements FieldVector {
8173
private ValueVector singleVector;
8274

8375
private final CallBack callBack;
76+
private final List<BufferBacked> innerVectors;
8477

8578
/**
 * Creates a union vector, storing the given name, allocator and callback and
 * building its two internal vectors: a {@code MapVector} named "internal" and
 * a {@code UInt1Vector} named "types".
 *
 * @param name      name stored on this vector
 * @param allocator allocator stored on this vector and passed to both internal vectors
 * @param callBack  callback stored on this vector and passed to the internal map
 */
public UnionVector(String name, BufferAllocator allocator, CallBack callBack) {
8679
this.name = name;
8780
this.allocator = allocator;
8881
this.internalMap = new MapVector("internal", allocator, callBack);
8982
this.typeVector = new UInt1Vector("types", allocator);
9083
this.callBack = callBack;
84+
// The only buffer-backed inner vector exposed is the type vector; the list is
// wrapped so that it cannot be modified through this reference.
// NOTE(review): `Arrays` is not among the imports visible in this hunk --
// presumably it comes from vv_imports.ftl; confirm.
this.innerVectors = Collections.unmodifiableList(Arrays.<BufferBacked>asList(typeVector));
9185
}
9286

9387
public BufferAllocator getAllocator() {
@@ -101,30 +95,28 @@ public MinorType getMinorType() {
10195

10296
@Override
10397
public void initializeChildrenFromFields(List<Field> children) {
104-
getMap().initializeChildrenFromFields(children);
98+
internalMap.initializeChildrenFromFields(children);
10599
}
106100

107101
@Override
108102
public List<FieldVector> getChildrenFromFields() {
109-
return getMap().getChildrenFromFields();
103+
return internalMap.getChildrenFromFields();
110104
}
111105

112106
@Override
113107
public void loadFieldBuffers(ArrowFieldNode fieldNode, List<ArrowBuf> ownBuffers) {
114-
// TODO
115-
throw new UnsupportedOperationException();
108+
BaseDataValueVector.load(getFieldInnerVectors(), ownBuffers);
109+
this.valueCount = fieldNode.getLength();
116110
}
117111

118112
@Override
119113
public List<ArrowBuf> getFieldBuffers() {
120-
// TODO
121-
throw new UnsupportedOperationException();
114+
return BaseDataValueVector.unload(getFieldInnerVectors());
122115
}
123116

124117
@Override
125118
public List<BufferBacked> getFieldInnerVectors() {
126-
// TODO
127-
throw new UnsupportedOperationException();
119+
return this.innerVectors;
128120
}
129121

130122
public NullableMapVector getMap() {

vector/src/main/java/org/apache/arrow/vector/VectorLoader.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ public void load(ArrowRecordBatch recordBatch) {
7474
}
7575

7676
private void loadBuffers(FieldVector vector, Field field, Iterator<ArrowBuf> buffers, Iterator<ArrowFieldNode> nodes) {
77+
checkArgument(nodes.hasNext(),
78+
"no more field nodes for for field " + field + " and vector " + vector);
7779
ArrowFieldNode fieldNode = nodes.next();
7880
List<VectorLayout> typeLayout = field.getTypeLayout().getVectors();
7981
List<ArrowBuf> ownBuffers = new ArrayList<>(typeLayout.size());

vector/src/main/java/org/apache/arrow/vector/schema/TypeLayout.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,7 @@ public static TypeLayout getTypeLayout(final ArrowType arrowType) {
8282
break;
8383
case UnionMode.Sparse:
8484
vectors = asList(
85-
validityVector(),
86-
typeVector()
85+
typeVector() // type of the value at the index or 0 if null
8786
);
8887
break;
8988
default:

vector/src/test/java/org/apache/arrow/vector/file/TestArrowFile.java

Lines changed: 109 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ private void validateComplexContent(int count, NullableMapVector parent) {
266266
Assert.assertEquals(i % 3, rootReader.reader("list").size());
267267
NullableTimeStampHolder h = new NullableTimeStampHolder();
268268
rootReader.reader("map").reader("timestamp").read(h);
269-
Assert.assertEquals(i, h.value % COUNT);
269+
Assert.assertEquals(i, h.value);
270270
}
271271
}
272272

@@ -339,4 +339,112 @@ public void testWriteReadMultipleRBs() throws IOException {
339339
}
340340
}
341341

342+
/**
 * Round-trip test for union data: populates a parent map vector through
 * {@code writeUnionData}, validates it in memory, writes its "root" child to
 * {@code target/mytest_write_union.arrow}, then reads the file back with an
 * {@code ArrowReader}, loads every record batch through a {@code VectorLoader}
 * and re-validates the content after each batch.
 *
 * @throws IOException declared by the method; presumably raised by the file
 *     write/read calls -- the helpers are not visible here
 */
@Test
343+
public void testWriteReadUnion() throws IOException {
344+
File file = new File("target/mytest_write_union.arrow");
345+
int count = COUNT;
346+
// write phase: the allocator and the parent vector are closed by try-with-resources
try (
347+
BufferAllocator vectorAllocator = allocator.newChildAllocator("original vectors", 0, Integer.MAX_VALUE);
348+
NullableMapVector parent = new NullableMapVector("parent", vectorAllocator, null)) {
349+
350+
writeUnionData(count, parent);
351+
352+
printVectors(parent.getChildrenFromFields());
353+
354+
// check the in-memory data before it is written out
validateUnionData(count, parent);
355+
356+
write(parent.getChild("root"), file);
357+
}
358+
// read
359+
try (
360+
BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE);
361+
FileInputStream fileInputStream = new FileInputStream(file);
362+
ArrowReader arrowReader = new ArrowReader(fileInputStream.getChannel(), readerAllocator);
363+
BufferAllocator vectorAllocator = allocator.newChildAllocator("final vectors", 0, Integer.MAX_VALUE);
364+
NullableMapVector parent = new NullableMapVector("parent", vectorAllocator, null)
365+
) {
366+
ArrowFooter footer = arrowReader.readFooter();
367+
Schema schema = footer.getSchema();
368+
LOGGER.debug("reading schema: " + schema);
369+
370+
// initialize vectors
371+
372+
NullableMapVector root = parent.addOrGet("root", MinorType.MAP, NullableMapVector.class);
373+
VectorLoader vectorLoader = new VectorLoader(schema, root);
374+
375+
List<ArrowBlock> recordBatches = footer.getRecordBatches();
376+
for (ArrowBlock rbBlock : recordBatches) {
377+
// each record batch is closed as soon as it has been loaded into the vectors
try (ArrowRecordBatch recordBatch = arrowReader.readRecordBatch(rbBlock)) {
378+
vectorLoader.load(recordBatch);
379+
}
380+
validateUnionData(count, parent);
381+
}
382+
}
383+
}
384+
385+
/**
 * Asserts that the "union" field under "root" holds the pattern produced by
 * {@code writeUnionData}. For each position {@code i} in {@code [0, count)}
 * the expected content depends on {@code i % 4}:
 * 0 - an integer equal to {@code i};
 * 1 - a long equal to {@code i};
 * 2 - a value whose reader reports size {@code i % 3};
 * 3 - a map whose "timestamp" field equals {@code i}.
 *
 * @param count  number of positions to check
 * @param parent vector whose "root" child is read
 */
public void validateUnionData(int count, MapVector parent) {
386+
MapReader rootReader = new SingleMapReaderImpl(parent).reader("root");
387+
for (int i = 0; i < count; i++) {
388+
rootReader.setPosition(i);
389+
switch (i % 4) {
390+
case 0:
391+
Assert.assertEquals(i, rootReader.reader("union").readInteger().intValue());
392+
break;
393+
case 1:
394+
Assert.assertEquals(i, rootReader.reader("union").readLong().longValue());
395+
break;
396+
case 2:
397+
// only the element count is checked here, not the element contents
Assert.assertEquals(i % 3, rootReader.reader("union").size());
398+
break;
399+
case 3:
400+
NullableTimeStampHolder h = new NullableTimeStampHolder();
401+
rootReader.reader("union").reader("timestamp").read(h);
402+
Assert.assertEquals(i, h.value);
403+
break;
404+
}
405+
}
406+
}
407+
408+
/**
 * Fills the "union" field under "root" with {@code count} values whose kind
 * cycles with {@code i % 4}:
 * 0 - an int equal to {@code i};
 * 1 - a bigint equal to {@code i};
 * 2 - a list of {@code i % 3} varchar entries, each written from the 3-byte buffer "abc";
 * 3 - a map with a "timestamp" field equal to {@code i}.
 * Four writers (int, bigint, list, map) are requested for the same field name "union".
 *
 * @param count  number of positions to write; also passed to {@code setValueCount}
 * @param parent vector handed to the {@code ComplexWriterImpl} named "root"
 */
public void writeUnionData(int count, NullableMapVector parent) {
409+
// scratch buffer holding the bytes 'a', 'b', 'c', reused for every varchar list entry
ArrowBuf varchar = allocator.buffer(3);
410+
varchar.readerIndex(0);
411+
varchar.setByte(0, 'a');
412+
varchar.setByte(1, 'b');
413+
varchar.setByte(2, 'c');
414+
varchar.writerIndex(3);
415+
ComplexWriter writer = new ComplexWriterImpl("root", parent);
416+
MapWriter rootWriter = writer.rootAsMap();
417+
IntWriter intWriter = rootWriter.integer("union");
418+
BigIntWriter bigIntWriter = rootWriter.bigInt("union");
419+
ListWriter listWriter = rootWriter.list("union");
420+
MapWriter mapWriter = rootWriter.map("union");
421+
for (int i = 0; i < count; i++) {
422+
switch (i % 4) {
423+
case 0:
424+
intWriter.setPosition(i);
425+
intWriter.writeInt(i);
426+
break;
427+
case 1:
428+
bigIntWriter.setPosition(i);
429+
bigIntWriter.writeBigInt(i);
430+
break;
431+
case 2:
432+
listWriter.setPosition(i);
433+
listWriter.startList();
434+
for (int j = 0; j < i % 3; j++) {
435+
listWriter.varChar().writeVarChar(0, 3, varchar);
436+
}
437+
listWriter.endList();
438+
break;
439+
case 3:
440+
mapWriter.setPosition(i);
441+
mapWriter.start();
442+
mapWriter.timeStamp("timestamp").writeTimeStamp(i);
443+
mapWriter.end();
444+
break;
445+
}
446+
}
447+
writer.setValueCount(count);
448+
// NOTE(review): this release is not in a finally block, so it is skipped if
// any write above throws.
varchar.release();
449+
}
342450
}

0 commit comments

Comments
 (0)