@@ -21,8 +21,9 @@ import scala.collection.JavaConverters._
2121import scala .language .implicitConversions
2222
2323import io .netty .buffer .ArrowBuf
24- import org .apache .arrow .memory .RootAllocator
25- import org .apache .arrow .vector .BitVector
24+ import org .apache .arrow .memory .{BaseAllocator , RootAllocator }
25+ import org .apache .arrow .vector ._
26+ import org .apache .arrow .vector .BaseValueVector .BaseMutator
2627import org .apache .arrow .vector .schema .{ArrowFieldNode , ArrowRecordBatch }
2728import org .apache .arrow .vector .types .FloatingPointPrecision
2829import org .apache .arrow .vector .types .pojo .{ArrowType , Field , Schema }
@@ -32,70 +33,17 @@ import org.apache.spark.sql.types._
3233
3334object Arrow {
3435
35- private case class TypeFuncs (getType : () => ArrowType ,
36- fill : ArrowBuf => Unit ,
37- write : (InternalRow , Int , ArrowBuf ) => Unit )
38-
39- private def getTypeFuncs (dataType : DataType ): TypeFuncs = {
40- val err = s " Unsupported data type ${dataType.simpleString}"
41-
36+ private def sparkTypeToArrowType (dataType : DataType ): ArrowType = {
4237 dataType match {
43- case NullType =>
44- TypeFuncs (
45- () => ArrowType .Null .INSTANCE ,
46- (buf : ArrowBuf ) => (),
47- (row : InternalRow , ordinal : Int , buf : ArrowBuf ) => ())
48- case BooleanType =>
49- TypeFuncs (
50- () => ArrowType .Bool .INSTANCE ,
51- (buf : ArrowBuf ) => buf.writeBoolean(false ),
52- (row : InternalRow , ordinal : Int , buf : ArrowBuf ) =>
53- buf.writeBoolean(row.getBoolean(ordinal)))
54- case ShortType =>
55- TypeFuncs (
56- () => new ArrowType .Int (8 * ShortType .defaultSize, true ),
57- (buf : ArrowBuf ) => buf.writeShort(0 ),
58- (row : InternalRow , ordinal : Int , buf : ArrowBuf ) => buf.writeShort(row.getShort(ordinal)))
59- case IntegerType =>
60- TypeFuncs (
61- () => new ArrowType .Int (8 * IntegerType .defaultSize, true ),
62- (buf : ArrowBuf ) => buf.writeInt(0 ),
63- (row : InternalRow , ordinal : Int , buf : ArrowBuf ) => buf.writeInt(row.getInt(ordinal)))
64- case LongType =>
65- TypeFuncs (
66- () => new ArrowType .Int (8 * LongType .defaultSize, true ),
67- (buf : ArrowBuf ) => buf.writeLong(0L ),
68- (row : InternalRow , ordinal : Int , buf : ArrowBuf ) => buf.writeLong(row.getLong(ordinal)))
69- case FloatType =>
70- TypeFuncs (
71- () => new ArrowType .FloatingPoint (FloatingPointPrecision .SINGLE ),
72- (buf : ArrowBuf ) => buf.writeFloat(0f ),
73- (row : InternalRow , ordinal : Int , buf : ArrowBuf ) => buf.writeFloat(row.getFloat(ordinal)))
74- case DoubleType =>
75- TypeFuncs (
76- () => new ArrowType .FloatingPoint (FloatingPointPrecision .DOUBLE ),
77- (buf : ArrowBuf ) => buf.writeDouble(0d ),
78- (row : InternalRow , ordinal : Int , buf : ArrowBuf ) =>
79- buf.writeDouble(row.getDouble(ordinal)))
80- case ByteType =>
81- TypeFuncs (
82- () => new ArrowType .Int (8 , false ),
83- (buf : ArrowBuf ) => buf.writeByte(0 ),
84- (row : InternalRow , ordinal : Int , buf : ArrowBuf ) => buf.writeByte(row.getByte(ordinal)))
85- case StringType =>
86- TypeFuncs (
87- () => ArrowType .Utf8 .INSTANCE ,
88- (buf : ArrowBuf ) => throw new UnsupportedOperationException (err), // TODO
89- (row : InternalRow , ordinal : Int , buf : ArrowBuf ) =>
90- throw new UnsupportedOperationException (err))
91- case StructType (_) =>
92- TypeFuncs (
93- () => ArrowType .Struct .INSTANCE ,
94- (buf : ArrowBuf ) => throw new UnsupportedOperationException (err), // TODO
95- (row : InternalRow , ordinal : Int , buf : ArrowBuf ) =>
96- throw new UnsupportedOperationException (err))
97- case _ =>
98- throw new IllegalArgumentException (err)
38+ case BooleanType => ArrowType .Bool .INSTANCE
39+ case ShortType => new ArrowType .Int (8 * ShortType .defaultSize, true )
40+ case IntegerType => new ArrowType .Int (8 * IntegerType .defaultSize, true )
41+ case LongType => new ArrowType .Int (8 * LongType .defaultSize, true )
42+ case FloatType => new ArrowType .FloatingPoint (FloatingPointPrecision .SINGLE )
43+ case DoubleType => new ArrowType .FloatingPoint (FloatingPointPrecision .DOUBLE )
44+ case ByteType => new ArrowType .Int (8 , false )
45+ case StringType => ArrowType .Utf8 .INSTANCE
46+ case _ => throw new UnsupportedOperationException (s " Unsupported data type: ${dataType}" )
9947 }
10048 }
10149
@@ -110,8 +58,8 @@ object Arrow {
11058 internalRowToArrowBuf(rows, ordinal, field, allocator)
11159 }
11260
113- val buffers = bufAndField.flatMap(_._1).toList.asJava
114- val fieldNodes = bufAndField.flatMap(_._2).toList.asJava
61+ val fieldNodes = bufAndField.flatMap(_._1).toList.asJava
62+ val buffers = bufAndField.flatMap(_._2).toList.asJava
11563
11664 new ArrowRecordBatch (rows.length, fieldNodes, buffers)
11765 }
@@ -123,67 +71,24 @@ object Arrow {
12371 rows : Array [InternalRow ],
12472 ordinal : Int ,
12573 field : StructField ,
126- allocator : RootAllocator ): (Array [ArrowBuf ], Array [ArrowFieldNode ]) = {
74+ allocator : RootAllocator ): (Array [ArrowFieldNode ], Array [ArrowBuf ]) = {
12775 val numOfRows = rows.length
76+ val columnWriter = ColumnWriter (allocator, field.dataType)
77+ columnWriter.init(numOfRows)
78+ var index = 0
12879
129- field.dataType match {
130- case ShortType | IntegerType | LongType | DoubleType | FloatType | BooleanType | ByteType =>
131- val validityVector = new BitVector (" validity" , allocator)
132- val validityMutator = validityVector.getMutator
133- validityVector.allocateNew(numOfRows)
134- validityMutator.setValueCount(numOfRows)
135-
136- val buf = allocator.buffer(numOfRows * field.dataType.defaultSize)
137- val typeFunc = getTypeFuncs(field.dataType)
138- var nullCount = 0
139- var index = 0
140- while (index < rows.length) {
141- val row = rows(index)
142- if (row.isNullAt(ordinal)) {
143- nullCount += 1
144- validityMutator.set(index, 0 )
145- typeFunc.fill(buf)
146- } else {
147- validityMutator.set(index, 1 )
148- typeFunc.write(row, ordinal, buf)
149- }
150- index += 1
151- }
152-
153- val fieldNode = new ArrowFieldNode (numOfRows, nullCount)
154-
155- (Array (validityVector.getBuffer, buf), Array (fieldNode))
156-
157- case StringType =>
158- val validityVector = new BitVector (" validity" , allocator)
159- val validityMutator = validityVector.getMutator()
160- validityVector.allocateNew(numOfRows)
161- validityMutator.setValueCount(numOfRows)
162-
163- val bufOffset = allocator.buffer((numOfRows + 1 ) * IntegerType .defaultSize)
164- var bytesCount = 0
165- bufOffset.writeInt(bytesCount)
166- val bufValues = allocator.buffer(1024 )
167- var nullCount = 0
168- rows.zipWithIndex.foreach { case (row, index) =>
169- if (row.isNullAt(ordinal)) {
170- nullCount += 1
171- validityMutator.set(index, 0 )
172- bufOffset.writeInt(bytesCount)
173- } else {
174- validityMutator.set(index, 1 )
175- val bytes = row.getUTF8String(ordinal).getBytes
176- bytesCount += bytes.length
177- bufOffset.writeInt(bytesCount)
178- bufValues.writeBytes(bytes)
179- }
180- }
181-
182- val fieldNode = new ArrowFieldNode (numOfRows, nullCount)
183-
184- (Array (validityVector.getBuffer, bufOffset, bufValues),
185- Array (fieldNode))
80+ while (index < numOfRows) {
81+ val row = rows(index)
82+ if (row.isNullAt(ordinal)) {
83+ columnWriter.writeNull()
84+ } else {
85+ columnWriter.write(row, ordinal)
86+ }
87+ index += 1
18688 }
89+
90+ val (arrowFieldNodes, arrowBufs) = columnWriter.finish()
91+ (arrowFieldNodes.toArray, arrowBufs.toArray)
18792 }
18893
18994 private [sql] def schemaToArrowSchema (schema : StructType ): Schema = {
@@ -195,13 +100,158 @@ object Arrow {
195100 val name = sparkField.name
196101 val dataType = sparkField.dataType
197102 val nullable = sparkField.nullable
103+ new Field (name, nullable, sparkTypeToArrowType(dataType), List .empty[Field ].asJava)
104+ }
105+ }
198106
107+ object ColumnWriter {
108+ def apply (allocator : BaseAllocator , dataType : DataType ): ColumnWriter = {
199109 dataType match {
200- case StructType (fields) =>
201- val childrenFields = fields.map(sparkFieldToArrowField).toList.asJava
202- new Field (name, nullable, ArrowType .Struct .INSTANCE , childrenFields)
203- case _ =>
204- new Field (name, nullable, getTypeFuncs(dataType).getType(), List .empty[Field ].asJava)
110+ case BooleanType => new BooleanColumnWriter (allocator)
111+ case ShortType => new ShortColumnWriter (allocator)
112+ case IntegerType => new IntegerColumnWriter (allocator)
113+ case LongType => new LongColumnWriter (allocator)
114+ case FloatType => new FloatColumnWriter (allocator)
115+ case DoubleType => new DoubleColumnWriter (allocator)
116+ case ByteType => new ByteColumnWriter (allocator)
117+ case StringType => new UTF8StringColumnWriter (allocator)
118+ case _ => throw new UnsupportedOperationException (s " Unsupported data type: ${dataType}" )
205119 }
206120 }
207121}
122+
123+ private [sql] trait ColumnWriter {
124+ def init (initialSize : Int ): Unit
125+ def writeNull (): Unit
126+ def write (row : InternalRow , ordinal : Int ): Unit
127+ def finish (): (Seq [ArrowFieldNode ], Seq [ArrowBuf ])
128+ }
129+
130+ /**
131+ * Base class for flat arrow column writer, i.e., column without children.
132+ */
133+ private [sql] abstract class PrimitiveColumnWriter (protected val allocator : BaseAllocator )
134+ extends ColumnWriter {
135+ protected val valueVector : BaseDataValueVector
136+ protected val valueMutator : BaseMutator
137+
138+ protected var count = 0
139+ protected var nullCount = 0
140+
141+ protected def setNull (): Unit
142+ protected def setValue (row : InternalRow , ordinal : Int ): Unit
143+ protected def valueBuffers (): Seq [ArrowBuf ] = valueVector.getBuffers(true ) // TODO: check the flag
144+
145+ override def init (initialSize : Int ): Unit = {
146+ valueVector.allocateNew()
147+ }
148+
149+ override def writeNull (): Unit = {
150+ setNull()
151+ nullCount += 1
152+ count += 1
153+ }
154+
155+ override def write (row : InternalRow , ordinal : Int ): Unit = {
156+ setValue(row, ordinal)
157+ count += 1
158+ }
159+
160+ override def finish (): (Seq [ArrowFieldNode ], Seq [ArrowBuf ]) = {
161+ valueMutator.setValueCount(count)
162+ val fieldNode = new ArrowFieldNode (count, nullCount)
163+ (List (fieldNode), valueBuffers)
164+ }
165+ }
166+
167+ private [sql] class BooleanColumnWriter (allocator : BaseAllocator )
168+ extends PrimitiveColumnWriter (allocator) {
169+ private def bool2int (b : Boolean ): Int = if (b) 1 else 0
170+
171+ override protected val valueVector : NullableBitVector
172+ = new NullableBitVector (" BooleanValue" , allocator)
173+ override protected val valueMutator : NullableBitVector # Mutator = valueVector.getMutator
174+
175+ override def setNull (): Unit = valueMutator.setNull(count)
176+ override def setValue (row : InternalRow , ordinal : Int ): Unit
177+ = valueMutator.setSafe(count, bool2int(row.getBoolean(ordinal)))
178+ }
179+
180+ private [sql] class ShortColumnWriter (allocator : BaseAllocator )
181+ extends PrimitiveColumnWriter (allocator) {
182+ override protected val valueVector : NullableSmallIntVector
183+ = new NullableSmallIntVector (" ShortValue" , allocator)
184+ override protected val valueMutator : NullableSmallIntVector # Mutator = valueVector.getMutator
185+
186+ override def setNull (): Unit = valueMutator.setNull(count)
187+ override def setValue (row : InternalRow , ordinal : Int ): Unit
188+ = valueMutator.setSafe(count, row.getShort(ordinal))
189+ }
190+
191+ private [sql] class IntegerColumnWriter (allocator : BaseAllocator )
192+ extends PrimitiveColumnWriter (allocator) {
193+ override protected val valueVector : NullableIntVector
194+ = new NullableIntVector (" IntValue" , allocator)
195+ override protected val valueMutator : NullableIntVector # Mutator = valueVector.getMutator
196+
197+ override def setNull (): Unit = valueMutator.setNull(count)
198+ override def setValue (row : InternalRow , ordinal : Int ): Unit
199+ = valueMutator.setSafe(count, row.getInt(ordinal))
200+ }
201+
202+ private [sql] class LongColumnWriter (allocator : BaseAllocator )
203+ extends PrimitiveColumnWriter (allocator) {
204+ override protected val valueVector : NullableBigIntVector
205+ = new NullableBigIntVector (" LongValue" , allocator)
206+ override protected val valueMutator : NullableBigIntVector # Mutator = valueVector.getMutator
207+
208+ override def setNull (): Unit = valueMutator.setNull(count)
209+ override def setValue (row : InternalRow , ordinal : Int ): Unit
210+ = valueMutator.setSafe(count, row.getLong(ordinal))
211+ }
212+
213+ private [sql] class FloatColumnWriter (allocator : BaseAllocator )
214+ extends PrimitiveColumnWriter (allocator) {
215+ override protected val valueVector : NullableFloat4Vector
216+ = new NullableFloat4Vector (" FloatValue" , allocator)
217+ override protected val valueMutator : NullableFloat4Vector # Mutator = valueVector.getMutator
218+
219+ override def setNull (): Unit = valueMutator.setNull(count)
220+ override def setValue (row : InternalRow , ordinal : Int ): Unit
221+ = valueMutator.setSafe(count, row.getFloat(ordinal))
222+ }
223+
224+ private [sql] class DoubleColumnWriter (allocator : BaseAllocator )
225+ extends PrimitiveColumnWriter (allocator) {
226+ override protected val valueVector : NullableFloat8Vector
227+ = new NullableFloat8Vector (" DoubleValue" , allocator)
228+ override protected val valueMutator : NullableFloat8Vector # Mutator = valueVector.getMutator
229+
230+ override def setNull (): Unit = valueMutator.setNull(count)
231+ override def setValue (row : InternalRow , ordinal : Int ): Unit
232+ = valueMutator.setSafe(count, row.getDouble(ordinal))
233+ }
234+
235+ private [sql] class ByteColumnWriter (allocator : BaseAllocator )
236+ extends PrimitiveColumnWriter (allocator) {
237+ override protected val valueVector : NullableUInt1Vector
238+ = new NullableUInt1Vector (" ByteValue" , allocator)
239+ override protected val valueMutator : NullableUInt1Vector # Mutator = valueVector.getMutator
240+
241+ override def setNull (): Unit = valueMutator.setNull(count)
242+ override def setValue (row : InternalRow , ordinal : Int ): Unit
243+ = valueMutator.setSafe(count, row.getByte(ordinal))
244+ }
245+
246+ private [sql] class UTF8StringColumnWriter (allocator : BaseAllocator )
247+ extends PrimitiveColumnWriter (allocator) {
248+ override protected val valueVector : NullableVarBinaryVector
249+ = new NullableVarBinaryVector (" UTF8StringValue" , allocator)
250+ override protected val valueMutator : NullableVarBinaryVector # Mutator = valueVector.getMutator
251+
252+ override def setNull (): Unit = valueMutator.setNull(count)
253+ override def setValue (row : InternalRow , ordinal : Int ): Unit = {
254+ val bytes = row.getUTF8String(ordinal).getBytes
255+ valueMutator.setSafe(count, bytes, 0 , bytes.length)
256+ }
257+ }
0 commit comments