Skip to content

Commit bdba357

Browse files
icexellossBryanCutler
authored andcommitted
Implement Arrow column writers
Move column writers to Arrow.scala Add support for more types; Switch to arrow NullableVector closes apache#16
1 parent 5837b38 commit bdba357

File tree

1 file changed

+180
-130
lines changed
  • sql/core/src/main/scala/org/apache/spark/sql

1 file changed

+180
-130
lines changed

sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala

Lines changed: 180 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@ import scala.collection.JavaConverters._
2121
import scala.language.implicitConversions
2222

2323
import io.netty.buffer.ArrowBuf
24-
import org.apache.arrow.memory.RootAllocator
25-
import org.apache.arrow.vector.BitVector
24+
import org.apache.arrow.memory.{BaseAllocator, RootAllocator}
25+
import org.apache.arrow.vector._
26+
import org.apache.arrow.vector.BaseValueVector.BaseMutator
2627
import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch}
2728
import org.apache.arrow.vector.types.FloatingPointPrecision
2829
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema}
@@ -32,70 +33,17 @@ import org.apache.spark.sql.types._
3233

3334
object Arrow {
3435

35-
private case class TypeFuncs(getType: () => ArrowType,
36-
fill: ArrowBuf => Unit,
37-
write: (InternalRow, Int, ArrowBuf) => Unit)
38-
39-
private def getTypeFuncs(dataType: DataType): TypeFuncs = {
40-
val err = s"Unsupported data type ${dataType.simpleString}"
41-
36+
private def sparkTypeToArrowType(dataType: DataType): ArrowType = {
4237
dataType match {
43-
case NullType =>
44-
TypeFuncs(
45-
() => ArrowType.Null.INSTANCE,
46-
(buf: ArrowBuf) => (),
47-
(row: InternalRow, ordinal: Int, buf: ArrowBuf) => ())
48-
case BooleanType =>
49-
TypeFuncs(
50-
() => ArrowType.Bool.INSTANCE,
51-
(buf: ArrowBuf) => buf.writeBoolean(false),
52-
(row: InternalRow, ordinal: Int, buf: ArrowBuf) =>
53-
buf.writeBoolean(row.getBoolean(ordinal)))
54-
case ShortType =>
55-
TypeFuncs(
56-
() => new ArrowType.Int(8 * ShortType.defaultSize, true),
57-
(buf: ArrowBuf) => buf.writeShort(0),
58-
(row: InternalRow, ordinal: Int, buf: ArrowBuf) => buf.writeShort(row.getShort(ordinal)))
59-
case IntegerType =>
60-
TypeFuncs(
61-
() => new ArrowType.Int(8 * IntegerType.defaultSize, true),
62-
(buf: ArrowBuf) => buf.writeInt(0),
63-
(row: InternalRow, ordinal: Int, buf: ArrowBuf) => buf.writeInt(row.getInt(ordinal)))
64-
case LongType =>
65-
TypeFuncs(
66-
() => new ArrowType.Int(8 * LongType.defaultSize, true),
67-
(buf: ArrowBuf) => buf.writeLong(0L),
68-
(row: InternalRow, ordinal: Int, buf: ArrowBuf) => buf.writeLong(row.getLong(ordinal)))
69-
case FloatType =>
70-
TypeFuncs(
71-
() => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE),
72-
(buf: ArrowBuf) => buf.writeFloat(0f),
73-
(row: InternalRow, ordinal: Int, buf: ArrowBuf) => buf.writeFloat(row.getFloat(ordinal)))
74-
case DoubleType =>
75-
TypeFuncs(
76-
() => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE),
77-
(buf: ArrowBuf) => buf.writeDouble(0d),
78-
(row: InternalRow, ordinal: Int, buf: ArrowBuf) =>
79-
buf.writeDouble(row.getDouble(ordinal)))
80-
case ByteType =>
81-
TypeFuncs(
82-
() => new ArrowType.Int(8, false),
83-
(buf: ArrowBuf) => buf.writeByte(0),
84-
(row: InternalRow, ordinal: Int, buf: ArrowBuf) => buf.writeByte(row.getByte(ordinal)))
85-
case StringType =>
86-
TypeFuncs(
87-
() => ArrowType.Utf8.INSTANCE,
88-
(buf: ArrowBuf) => throw new UnsupportedOperationException(err), // TODO
89-
(row: InternalRow, ordinal: Int, buf: ArrowBuf) =>
90-
throw new UnsupportedOperationException(err))
91-
case StructType(_) =>
92-
TypeFuncs(
93-
() => ArrowType.Struct.INSTANCE,
94-
(buf: ArrowBuf) => throw new UnsupportedOperationException(err), // TODO
95-
(row: InternalRow, ordinal: Int, buf: ArrowBuf) =>
96-
throw new UnsupportedOperationException(err))
97-
case _ =>
98-
throw new IllegalArgumentException(err)
38+
case BooleanType => ArrowType.Bool.INSTANCE
39+
case ShortType => new ArrowType.Int(8 * ShortType.defaultSize, true)
40+
case IntegerType => new ArrowType.Int(8 * IntegerType.defaultSize, true)
41+
case LongType => new ArrowType.Int(8 * LongType.defaultSize, true)
42+
case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)
43+
case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)
44+
case ByteType => new ArrowType.Int(8, false)
45+
case StringType => ArrowType.Utf8.INSTANCE
46+
case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dataType}")
9947
}
10048
}
10149

@@ -110,8 +58,8 @@ object Arrow {
11058
internalRowToArrowBuf(rows, ordinal, field, allocator)
11159
}
11260

113-
val buffers = bufAndField.flatMap(_._1).toList.asJava
114-
val fieldNodes = bufAndField.flatMap(_._2).toList.asJava
61+
val fieldNodes = bufAndField.flatMap(_._1).toList.asJava
62+
val buffers = bufAndField.flatMap(_._2).toList.asJava
11563

11664
new ArrowRecordBatch(rows.length, fieldNodes, buffers)
11765
}
@@ -123,67 +71,24 @@ object Arrow {
12371
rows: Array[InternalRow],
12472
ordinal: Int,
12573
field: StructField,
126-
allocator: RootAllocator): (Array[ArrowBuf], Array[ArrowFieldNode]) = {
74+
allocator: RootAllocator): (Array[ArrowFieldNode], Array[ArrowBuf]) = {
12775
val numOfRows = rows.length
76+
val columnWriter = ColumnWriter(allocator, field.dataType)
77+
columnWriter.init(numOfRows)
78+
var index = 0
12879

129-
field.dataType match {
130-
case ShortType | IntegerType | LongType | DoubleType | FloatType | BooleanType | ByteType =>
131-
val validityVector = new BitVector("validity", allocator)
132-
val validityMutator = validityVector.getMutator
133-
validityVector.allocateNew(numOfRows)
134-
validityMutator.setValueCount(numOfRows)
135-
136-
val buf = allocator.buffer(numOfRows * field.dataType.defaultSize)
137-
val typeFunc = getTypeFuncs(field.dataType)
138-
var nullCount = 0
139-
var index = 0
140-
while (index < rows.length) {
141-
val row = rows(index)
142-
if (row.isNullAt(ordinal)) {
143-
nullCount += 1
144-
validityMutator.set(index, 0)
145-
typeFunc.fill(buf)
146-
} else {
147-
validityMutator.set(index, 1)
148-
typeFunc.write(row, ordinal, buf)
149-
}
150-
index += 1
151-
}
152-
153-
val fieldNode = new ArrowFieldNode(numOfRows, nullCount)
154-
155-
(Array(validityVector.getBuffer, buf), Array(fieldNode))
156-
157-
case StringType =>
158-
val validityVector = new BitVector("validity", allocator)
159-
val validityMutator = validityVector.getMutator()
160-
validityVector.allocateNew(numOfRows)
161-
validityMutator.setValueCount(numOfRows)
162-
163-
val bufOffset = allocator.buffer((numOfRows + 1) * IntegerType.defaultSize)
164-
var bytesCount = 0
165-
bufOffset.writeInt(bytesCount)
166-
val bufValues = allocator.buffer(1024)
167-
var nullCount = 0
168-
rows.zipWithIndex.foreach { case (row, index) =>
169-
if (row.isNullAt(ordinal)) {
170-
nullCount += 1
171-
validityMutator.set(index, 0)
172-
bufOffset.writeInt(bytesCount)
173-
} else {
174-
validityMutator.set(index, 1)
175-
val bytes = row.getUTF8String(ordinal).getBytes
176-
bytesCount += bytes.length
177-
bufOffset.writeInt(bytesCount)
178-
bufValues.writeBytes(bytes)
179-
}
180-
}
181-
182-
val fieldNode = new ArrowFieldNode(numOfRows, nullCount)
183-
184-
(Array(validityVector.getBuffer, bufOffset, bufValues),
185-
Array(fieldNode))
80+
while(index < numOfRows) {
81+
val row = rows(index)
82+
if (row.isNullAt(ordinal)) {
83+
columnWriter.writeNull()
84+
} else {
85+
columnWriter.write(row, ordinal)
86+
}
87+
index += 1
18688
}
89+
90+
val (arrowFieldNodes, arrowBufs) = columnWriter.finish()
91+
(arrowFieldNodes.toArray, arrowBufs.toArray)
18792
}
18893

18994
private[sql] def schemaToArrowSchema(schema: StructType): Schema = {
@@ -195,13 +100,158 @@ object Arrow {
195100
val name = sparkField.name
196101
val dataType = sparkField.dataType
197102
val nullable = sparkField.nullable
103+
new Field(name, nullable, sparkTypeToArrowType(dataType), List.empty[Field].asJava)
104+
}
105+
}
198106

107+
object ColumnWriter {
108+
def apply(allocator: BaseAllocator, dataType: DataType): ColumnWriter = {
199109
dataType match {
200-
case StructType(fields) =>
201-
val childrenFields = fields.map(sparkFieldToArrowField).toList.asJava
202-
new Field(name, nullable, ArrowType.Struct.INSTANCE, childrenFields)
203-
case _ =>
204-
new Field(name, nullable, getTypeFuncs(dataType).getType(), List.empty[Field].asJava)
110+
case BooleanType => new BooleanColumnWriter(allocator)
111+
case ShortType => new ShortColumnWriter(allocator)
112+
case IntegerType => new IntegerColumnWriter(allocator)
113+
case LongType => new LongColumnWriter(allocator)
114+
case FloatType => new FloatColumnWriter(allocator)
115+
case DoubleType => new DoubleColumnWriter(allocator)
116+
case ByteType => new ByteColumnWriter(allocator)
117+
case StringType => new UTF8StringColumnWriter(allocator)
118+
case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dataType}")
205119
}
206120
}
207121
}
122+
123+
private[sql] trait ColumnWriter {
124+
def init(initialSize: Int): Unit
125+
def writeNull(): Unit
126+
def write(row: InternalRow, ordinal: Int): Unit
127+
def finish(): (Seq[ArrowFieldNode], Seq[ArrowBuf])
128+
}
129+
130+
/**
131+
* Base class for flat arrow column writer, i.e., column without children.
132+
*/
133+
private[sql] abstract class PrimitiveColumnWriter(protected val allocator: BaseAllocator)
134+
extends ColumnWriter {
135+
protected val valueVector: BaseDataValueVector
136+
protected val valueMutator: BaseMutator
137+
138+
protected var count = 0
139+
protected var nullCount = 0
140+
141+
protected def setNull(): Unit
142+
protected def setValue(row: InternalRow, ordinal: Int): Unit
143+
protected def valueBuffers(): Seq[ArrowBuf] = valueVector.getBuffers(true) // TODO: check the flag
144+
145+
override def init(initialSize: Int): Unit = {
146+
valueVector.allocateNew()
147+
}
148+
149+
override def writeNull(): Unit = {
150+
setNull()
151+
nullCount += 1
152+
count += 1
153+
}
154+
155+
override def write(row: InternalRow, ordinal: Int): Unit = {
156+
setValue(row, ordinal)
157+
count += 1
158+
}
159+
160+
override def finish(): (Seq[ArrowFieldNode], Seq[ArrowBuf]) = {
161+
valueMutator.setValueCount(count)
162+
val fieldNode = new ArrowFieldNode(count, nullCount)
163+
(List(fieldNode), valueBuffers)
164+
}
165+
}
166+
167+
private[sql] class BooleanColumnWriter(allocator: BaseAllocator)
168+
extends PrimitiveColumnWriter(allocator) {
169+
private def bool2int(b: Boolean): Int = if (b) 1 else 0
170+
171+
override protected val valueVector: NullableBitVector
172+
= new NullableBitVector("BooleanValue", allocator)
173+
override protected val valueMutator: NullableBitVector#Mutator = valueVector.getMutator
174+
175+
override def setNull(): Unit = valueMutator.setNull(count)
176+
override def setValue(row: InternalRow, ordinal: Int): Unit
177+
= valueMutator.setSafe(count, bool2int(row.getBoolean(ordinal)))
178+
}
179+
180+
private[sql] class ShortColumnWriter(allocator: BaseAllocator)
181+
extends PrimitiveColumnWriter(allocator) {
182+
override protected val valueVector: NullableSmallIntVector
183+
= new NullableSmallIntVector("ShortValue", allocator)
184+
override protected val valueMutator: NullableSmallIntVector#Mutator = valueVector.getMutator
185+
186+
override def setNull(): Unit = valueMutator.setNull(count)
187+
override def setValue(row: InternalRow, ordinal: Int): Unit
188+
= valueMutator.setSafe(count, row.getShort(ordinal))
189+
}
190+
191+
private[sql] class IntegerColumnWriter(allocator: BaseAllocator)
192+
extends PrimitiveColumnWriter(allocator) {
193+
override protected val valueVector: NullableIntVector
194+
= new NullableIntVector("IntValue", allocator)
195+
override protected val valueMutator: NullableIntVector#Mutator = valueVector.getMutator
196+
197+
override def setNull(): Unit = valueMutator.setNull(count)
198+
override def setValue(row: InternalRow, ordinal: Int): Unit
199+
= valueMutator.setSafe(count, row.getInt(ordinal))
200+
}
201+
202+
private[sql] class LongColumnWriter(allocator: BaseAllocator)
203+
extends PrimitiveColumnWriter(allocator) {
204+
override protected val valueVector: NullableBigIntVector
205+
= new NullableBigIntVector("LongValue", allocator)
206+
override protected val valueMutator: NullableBigIntVector#Mutator = valueVector.getMutator
207+
208+
override def setNull(): Unit = valueMutator.setNull(count)
209+
override def setValue(row: InternalRow, ordinal: Int): Unit
210+
= valueMutator.setSafe(count, row.getLong(ordinal))
211+
}
212+
213+
private[sql] class FloatColumnWriter(allocator: BaseAllocator)
214+
extends PrimitiveColumnWriter(allocator) {
215+
override protected val valueVector: NullableFloat4Vector
216+
= new NullableFloat4Vector("FloatValue", allocator)
217+
override protected val valueMutator: NullableFloat4Vector#Mutator = valueVector.getMutator
218+
219+
override def setNull(): Unit = valueMutator.setNull(count)
220+
override def setValue(row: InternalRow, ordinal: Int): Unit
221+
= valueMutator.setSafe(count, row.getFloat(ordinal))
222+
}
223+
224+
private[sql] class DoubleColumnWriter(allocator: BaseAllocator)
225+
extends PrimitiveColumnWriter(allocator) {
226+
override protected val valueVector: NullableFloat8Vector
227+
= new NullableFloat8Vector("DoubleValue", allocator)
228+
override protected val valueMutator: NullableFloat8Vector#Mutator = valueVector.getMutator
229+
230+
override def setNull(): Unit = valueMutator.setNull(count)
231+
override def setValue(row: InternalRow, ordinal: Int): Unit
232+
= valueMutator.setSafe(count, row.getDouble(ordinal))
233+
}
234+
235+
private[sql] class ByteColumnWriter(allocator: BaseAllocator)
236+
extends PrimitiveColumnWriter(allocator) {
237+
override protected val valueVector: NullableUInt1Vector
238+
= new NullableUInt1Vector("ByteValue", allocator)
239+
override protected val valueMutator: NullableUInt1Vector#Mutator = valueVector.getMutator
240+
241+
override def setNull(): Unit = valueMutator.setNull(count)
242+
override def setValue(row: InternalRow, ordinal: Int): Unit
243+
= valueMutator.setSafe(count, row.getByte(ordinal))
244+
}
245+
246+
private[sql] class UTF8StringColumnWriter(allocator: BaseAllocator)
247+
extends PrimitiveColumnWriter(allocator) {
248+
override protected val valueVector: NullableVarBinaryVector
249+
= new NullableVarBinaryVector("UTF8StringValue", allocator)
250+
override protected val valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator
251+
252+
override def setNull(): Unit = valueMutator.setNull(count)
253+
override def setValue(row: InternalRow, ordinal: Int): Unit = {
254+
val bytes = row.getUTF8String(ordinal).getBytes
255+
valueMutator.setSafe(count, bytes, 0, bytes.length)
256+
}
257+
}

0 commit comments

Comments
 (0)